// client.cpp
#define _WIN32_DCOM
#include <iostream.h>
#include <stdio.h>
#include <time.h>
#include <stdlib.h>
#include "Component\component.h"

// Client-side IMessageFilter implementation that main() registers with
// CoRegisterMessageFilter. Reference counting is not real: AddRef/Release
// return fixed values and never delete the object, so whoever creates the
// object owns its storage.
class CMessageFilter : public IMessageFilter
{
public:
	// ---- IMessageFilter ----
	virtual DWORD __stdcall HandleInComingCall(DWORD dwCallType, HTASK htaskCaller, DWORD dwTickCount, LPINTERFACEINFO lpInterfaceInfo);
	virtual DWORD __stdcall RetryRejectedCall(HTASK htaskCallee, DWORD dwTickCount, DWORD dwRejectType);
	virtual DWORD __stdcall MessagePending(HTASK htaskCallee, DWORD dwTickCount, DWORD dwPendingType);

	// ---- IUnknown ----
	virtual HRESULT __stdcall QueryInterface(REFIID riid, void** ppv);
	virtual ULONG __stdcall AddRef();
	virtual ULONG __stdcall Release();
};

// IUnknown::AddRef. No count is kept: the object's lifetime is managed by its
// creator, so this only reports a fixed dummy value (COM documents AddRef's
// return value as being for diagnostic use only).
ULONG CMessageFilter::AddRef()
{
	const ULONG ulDummyCount = 2;
	return ulDummyCount;
}

// IUnknown::Release. Counterpart of the no-op AddRef: it never frees the
// object and always reports a fixed dummy value.
ULONG CMessageFilter::Release()
{
	const ULONG ulDummyCount = 1;
	return ulDummyCount;
}

// IUnknown::QueryInterface. Hands out this object for IID_IUnknown and
// IID_IMessageFilter; any other IID yields E_NOINTERFACE with *ppv cleared.
// Returns E_POINTER when ppv itself is NULL.
HRESULT CMessageFilter::QueryInterface(REFIID riid, void **ppv)
{
	if(ppv == NULL)
		return E_POINTER;

	if(riid == IID_IUnknown || riid == IID_IMessageFilter)
	{
		*ppv = static_cast<IMessageFilter*>(this);
		// COM rule: every interface pointer returned through QI is AddRef'd.
		// (AddRef is a no-op for this class, but the call keeps the contract
		// intact should real reference counting be added later.)
		AddRef();
		return S_OK;
	}

	*ppv = NULL;
	return E_NOINTERFACE;
}

// IMessageFilter::HandleInComingCall. COM expects one of the SERVERCALL_*
// codes here (ISHANDLED / REJECTED / RETRYLATER); the previous E_NOTIMPL is
// not a valid answer. This client does not filter incoming calls, so every
// call is accepted. The parameters are unused.
DWORD CMessageFilter::HandleInComingCall(DWORD dwCallType, HTASK htaskCaller, DWORD dwTickCount, LPINTERFACEINFO lpInterfaceInfo)
{
	return SERVERCALL_ISHANDLED;
}

// IMessageFilter::RetryRejectedCall. Invoked when the server rejected an
// outgoing call. Per the IMessageFilter contract, a return value >= 100 asks
// COM to wait that many milliseconds and retry, and (DWORD)-1 cancels the call.
//
// Demo policy: retry every 500 ms for a random budget of 0..10 retries, chosen
// once with rand() on the first rejection. Both statics persist for the process
// lifetime and are never reset, so once the budget is spent, every later
// rejection is cancelled immediately.
DWORD CMessageFilter::RetryRejectedCall(HTASK htaskCallee, DWORD dwTickCount, DWORD dwRejectType)
{
	cout << "Client: CMessageFilter::RetryRejectedCall" << endl;

	static int counter = 0;  // was "static counter" (implicit int: not valid C++)
	static int randa = (int)((((float)rand())/RAND_MAX)*10);
	cout << "randa = " << randa << endl;

	const DWORD dwRetryDelayMs = 500;
	const DWORD dwCancelCall = (DWORD)-1;

	if(counter++ < randa)
		return dwRetryDelayMs;
	return dwCancelCall;
}

// IMessageFilter::MessagePending. Logs the notification and always answers
// PENDINGMSG_WAITNOPROCESS, i.e. keep waiting for the reply without
// dispatching the pending message. The parameters are unused.
DWORD CMessageFilter::MessagePending(HTASK htaskCallee, DWORD dwTickCount, DWORD dwPendingType)
{
	const DWORD dwDecision = PENDINGMSG_WAITNOPROCESS;
	cout << "Client: CMessageFilter::MessagePending" << endl;
	return dwDecision;
}

void ErrorMessage(char* szMessage, HRESULT hr)
{
    if(HRESULT_FACILITY(hr) == FACILITY_WINDOWS)
        hr = HRESULT_CODE(hr);

    char* szError;
    if(FormatMessage(FORMAT_MESSAGE_ALLOCATE_BUFFER|
        FORMAT_MESSAGE_FROM_SYSTEM, NULL, hr, 
        MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT), 
        (LPTSTR)&szError, 0, NULL) != 0)
    {
        printf("%s: (%0x) %s", szMessage, hr, szError);
        LocalFree(szError);
    }
}

void main()
{
	srand(GetTickCount());

	cout << "Client: Calling CoInitialize()" << endl;
	CoInitialize(NULL);

	IMessageFilter* pMF = new CMessageFilter;
	IMessageFilter* pOldMF;
	CoRegisterMessageFilter(pMF, &pOldMF);

	IUnknown* pUnknown;
	ISum* pSum;

	cout << "Client: Calling CoCreateInstance() " << endl;
	CoCreateInstance(CLSID_InsideCOM, NULL, CLSCTX_LOCAL_SERVER, IID_IUnknown, (void**)&pUnknown);

	cout << "Client: Calling QueryInterface() for ISum on " << pUnknown << endl;
	HRESULT hr = pUnknown->QueryInterface(IID_ISum, (void**)&pSum);

	if(FAILED(hr))
		cout << "QueryInterface FAILED" << endl;

	cout << "Client: Calling Release() for pUnknown" << endl;
	hr = pUnknown->Release();

	cout << "Client: pSum = " << pSum << endl;

	int sum;
	hr = pSum->Sum(4, 9, &sum);
	if(SUCCEEDED(hr))
		cout << "Client: Calling Sum() return value is " << sum << endl;
	else
		ErrorMessage("Sum", hr);

	cout << "Client: Calling Release() for pSum" << endl;
	hr = pSum->Release();

	cout << "Client: Calling CoUninitialize()" << endl;
	CoUninitialize();
}